{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "71bf747a",
   "metadata": {},
   "source": [
    "# Introduction to Taxi ETL Job\n",
    "This is the Taxi ETL job to generate the input datasets for the Taxi XGBoost job."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0524408",
   "metadata": {},
   "source": [
    "## Prerequirement\n",
    "### 1. Download data\n",
    "All data could be found at https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\n",
    "\n",
    "### 2. Download needed jars\n",
    "* [rapids-4-spark_2.12-25.04.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/25.04.0/rapids-4-spark_2.12-25.04.0.jar)\n",
    "\n",
    "### 3. Start Spark Standalone\n",
    "Before running the script, please setup Spark standalone mode\n",
    "\n",
    "### 4. Add ENV\n",
    "```\n",
    "$ export SPARK_JARS=rapids-4-spark_2.12-25.04.0.jar\n",
    "$ export PYSPARK_DRIVER_PYTHON=jupyter \n",
    "$ export PYSPARK_DRIVER_PYTHON_OPTS=notebook\n",
    "```\n",
    "\n",
    "### 5. Start Jupyter Notebook with plugin config\n",
    "\n",
    "```\n",
    "$ pyspark --master ${SPARK_MASTER}            \\\n",
    "--jars ${SPARK_JARS}                \\\n",
    "--conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n",
    "--conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\n",
    "--conf spark.rapids.sql.csv.read.double.enabled=true \\\n",
    "--py-files ${SPARK_PY_FILES}\n",
    "```\n",
    "\n",
    "## Import Libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d2283aab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os\n",
    "import math\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import *\n",
    "from pyspark.sql.types import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7ffcace",
   "metadata": {},
   "source": [
    "## Script Settings\n",
    "\n",
    "###  File Path Settings\n",
    "* Define input/output file path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b348778a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# You need to update them to your real paths! You can download the dataset \n",
    "# from https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\n",
    "# or you can just unzip datasets/taxi-small.tar.gz and use the provided\n",
    "# sample dataset datasets/taxi/taxi-etl-input-small.csv\n",
    "dataRoot = os.getenv('DATA_ROOT', '/data')\n",
    "rawPath = dataRoot + '/taxi/taxi-etl-input-small.csv'\n",
    "outPath = dataRoot + '/taxi/output'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a500530",
   "metadata": {},
   "source": [
    "## Function and Object Define\n",
    "### Define the constants\n",
    "\n",
    "* Define input file schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "094f31c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_schema = StructType([\n",
    "    StructField('vendor_id', StringType()),\n",
    "    StructField('pickup_datetime', StringType()),\n",
    "    StructField('dropoff_datetime', StringType()),\n",
    "    StructField('passenger_count', IntegerType()),\n",
    "    StructField('trip_distance', DoubleType()),\n",
    "    StructField('pickup_longitude', DoubleType()),\n",
    "    StructField('pickup_latitude', DoubleType()),\n",
    "    StructField('rate_code', StringType()),\n",
    "    StructField('store_and_fwd_flag', StringType()),\n",
    "    StructField('dropoff_longitude', DoubleType()),\n",
    "    StructField('dropoff_latitude', DoubleType()),\n",
    "    StructField('payment_type', StringType()),\n",
    "    StructField('fare_amount', DoubleType()),\n",
    "    StructField('surcharge', DoubleType()),\n",
    "    StructField('mta_tax', DoubleType()),\n",
    "    StructField('tip_amount', DoubleType()),\n",
    "    StructField('tolls_amount', DoubleType()),\n",
    "    StructField('total_amount', DoubleType()),\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72a4ae18",
   "metadata": {},
   "source": [
    "* Define some ETL functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b45b7606",
   "metadata": {},
   "outputs": [],
   "source": [
    "def drop_useless(data_frame):\n",
    "    return data_frame.drop(\n",
    "        'dropoff_datetime',\n",
    "        'payment_type',\n",
    "        'surcharge',\n",
    "        'mta_tax',\n",
    "        'tip_amount',\n",
    "        'tolls_amount',\n",
    "        'total_amount')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7af7073d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_categories(data_frame):\n",
    "    categories = [ 'vendor_id', 'rate_code', 'store_and_fwd_flag' ]\n",
    "    for category in categories:\n",
    "        data_frame = data_frame.withColumn(category, hash(col(category)))\n",
    "    return data_frame.withColumnRenamed(\"store_and_fwd_flag\", \"store_and_fwd\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b799cd5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fill_na(data_frame):\n",
    "    return data_frame.fillna(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ceee5c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_invalid(data_frame):\n",
    "    conditions = [\n",
    "        ( 'fare_amount', 0, 500 ),\n",
    "        ( 'passenger_count', 0, 6 ),\n",
    "        ( 'pickup_longitude', -75, -73 ),\n",
    "        ( 'dropoff_longitude', -75, -73 ),\n",
    "        ( 'pickup_latitude', 40, 42 ),\n",
    "        ( 'dropoff_latitude', 40, 42 ),\n",
    "    ]\n",
    "    for column, min, max in conditions:\n",
    "        data_frame = data_frame.filter('{} > {} and {} < {}'.format(column, min, column, max))\n",
    "    return data_frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bd28ae14",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_datetime(data_frame):\n",
    "    datetime = col('pickup_datetime')\n",
    "    return (data_frame\n",
    "        .withColumn('pickup_datetime', to_timestamp(datetime))\n",
    "        .withColumn('year', year(datetime))\n",
    "        .withColumn('month', month(datetime))\n",
    "        .withColumn('day', dayofmonth(datetime))\n",
    "        .withColumn('day_of_week', dayofweek(datetime))\n",
    "        .withColumn(\n",
    "            'is_weekend',\n",
    "            col('day_of_week').isin(1, 7).cast(IntegerType()))  # 1: Sunday, 7: Saturday\n",
    "        .withColumn('hour', hour(datetime))\n",
    "        .drop('pickup_datetime'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "39e45f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_h_distance(data_frame):\n",
    "    p = math.pi / 180\n",
    "    lat1 = col('pickup_latitude')\n",
    "    lon1 = col('pickup_longitude')\n",
    "    lat2 = col('dropoff_latitude')\n",
    "    lon2 = col('dropoff_longitude')\n",
    "    internal_value = (0.5\n",
    "        - cos((lat2 - lat1) * p) / 2\n",
    "        + cos(lat1 * p) * cos(lat2 * p) * (1 - cos((lon2 - lon1) * p)) / 2)\n",
    "    h_distance = 12734 * asin(sqrt(internal_value))\n",
    "    return data_frame.withColumn('h_distance', h_distance)"
   ]
  },
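  {
   "cell_type": "markdown",
   "id": "haversine-note",
   "metadata": {},
   "source": [
    "As a reading aid, the expression in `add_h_distance` matches the haversine formula for great-circle distance, rewritten with the identity $\\sin^2(x/2) = (1 - \\cos x)/2$. The symbols below are notation introduced here for illustration and do not appear in the code: $\\varphi_1, \\varphi_2$ are the pickup and dropoff latitudes and $\\lambda_1, \\lambda_2$ the pickup and dropoff longitudes, all in radians (the code converts degrees with `p = math.pi / 180`), and $R$ is the Earth radius.\n",
    "\n",
    "$$a = \\sin^2\\frac{\\varphi_2 - \\varphi_1}{2} + \\cos\\varphi_1 \\cos\\varphi_2 \\sin^2\\frac{\\lambda_2 - \\lambda_1}{2}$$\n",
    "\n",
    "$$a = \\frac{1 - \\cos(\\varphi_2 - \\varphi_1)}{2} + \\cos\\varphi_1 \\cos\\varphi_2 \\, \\frac{1 - \\cos(\\lambda_2 - \\lambda_1)}{2}$$\n",
    "\n",
    "$$d = 2R \\arcsin\\sqrt{a}$$\n",
    "\n",
    "The second form of $a$ is what `internal_value` computes. Under this reading, the constant `12734` plays the role of $2R$ in kilometers, which corresponds to $R = 6367$ km, so `h_distance` is a distance in kilometers."
   ]
  },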
  {
   "cell_type": "markdown",
   "id": "d52b062c",
   "metadata": {},
   "source": [
    "* Define main ETL function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9fd36618",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pre_process(data_frame):\n",
    "    processes = [\n",
    "        drop_useless,\n",
    "        encode_categories,\n",
    "        fill_na,\n",
    "        remove_invalid,\n",
    "        convert_datetime,\n",
    "        add_h_distance,\n",
    "    ]\n",
    "    for process in processes:\n",
    "        data_frame = process(data_frame)\n",
    "    return data_frame"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2798f19a",
   "metadata": {},
   "source": [
    "## Run ETL Process and Save the Result\n",
    "* Create Spark Session and create dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "26ca4ca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "spark = (SparkSession\n",
    "    .builder\n",
    "    .appName(\"Taxi-ETL\")\n",
    "    .getOrCreate())\n",
    "reader = (spark\n",
    "        .read\n",
    "        .format('csv'))\n",
    "reader.schema(raw_schema).option('header', 'True')\n",
    "\n",
    "raw_data = reader.load(rawPath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6243b736",
   "metadata": {},
   "source": [
    "* Run ETL Process and Save the Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "27f2119b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.114504098892212\n"
     ]
    }
   ],
   "source": [
    "start = time.time()\n",
    "etled_train, etled_eval, etled_trans = pre_process(raw_data).randomSplit(list(map(float, (80,20,0))))\n",
    "etled_train.write.mode(\"overwrite\").parquet(outPath+'/train')\n",
    "etled_eval.write.mode(\"overwrite\").parquet(outPath+'/eval')\n",
    "etled_trans.write.mode(\"overwrite\").parquet(outPath+'/trans')\n",
    "end = time.time()\n",
    "print(end - start)\n",
    "spark.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
